// STL
#include <cstdlib>  // std::exit, EXIT_FAILURE
#include <fstream>  // std::ifstream
#include <iostream> // std::cerr
#include <string>   // std::string
#include <tuple>    // std::tuple<>
#include <vector>   // std::vector<>

// Boost
#include <boost/program_options.hpp> // boost::program_options


//
// DESC: Read the parameters file.
//
// Parses `param_filename` as a Boost.Program_options config (INI-style) file
// and returns every parameter as a single tuple, in the order listed in the
// return type below. Options absent from the file keep the defaults declared
// at the top of the body.
//
// The calculation type(s) are inferred from which sections appear in the file:
//   create = true if any "create.*" option is given
//   train  = true if any "train.*"  option is given
//   test   = true if any "test.*"   option is given
//
// Consistency checks:
//   - create.nunits and create.units_type must have the same length (always
//     checked), and must be non-empty when create is true
//   - train.X_file must be given when train is true
//   - train.lr_dot / train.lr_adj and train.p_dot / train.p_adj must have
//     matching lengths; when empty they are filled with built-in defaults
//   - test.X_file and test.H_file must both be given when test is true
//
// Errors: if the file cannot be opened, cannot be parsed, lacks the required
// dbn_file option, or fails a check above, a message is written to std::cerr
// and the process terminates with EXIT_FAILURE (the function does not return).
//
std::tuple<
           std::string,         // dbn_file
           // =====
           bool,                // create
           // -----
           std::vector<int>,    // create_nunits
           std::vector<int>,    // create_units_type
           // =====
           bool,                // train
           // -----
           std::string,         // train_X_file
           // -----
           int,                 // train_nepoch
           int,                 // train_npts_per_batch
           // -----
           int,                 // train_K_start
           int,                 // train_K_end
           double,              // train_K_rate
           // -----
           bool,                // train_sample_v
           bool,                // train_sample_hdata
           // -----
           double,              // train_lr
           std::vector<double>, // train_lr_dot
           std::vector<double>, // train_lr_adj
           double,              // train_lr_min
           double,              // train_lr_max
           // -----
           double,              // train_p_start
           double,              // train_p_end
           double,              // train_p_rate
           std::vector<double>, // train_p_dot
           std::vector<double>, // train_p_adj
           // -----
           double,              // train_w_penalty
           // -----
           int,                 // train_convg_nno_improvement
           double,              // train_convg_max_inc_max_w
           // =====
           bool,                // test
           // -----
           std::string,         // test_X_file
           std::string          // test_H_file
           > read_param_file(
                             const std::string param_filename
                             )
{
    namespace po = boost::program_options;

    // Report a fatal parameter error and terminate the process.
    // A non-zero exit status is used so that scripts / callers can detect the failure.
    const auto fail = [](const std::string& msg)
    {
        std::cerr << "error in read_param_file(): " << msg << '\n';
        std::exit(EXIT_FAILURE);
    };

    //=========================================================

    // ----- PARAMETERS -----
    std::string         dbn_file;            // required
    // =====
    bool                create               = false;
    // -----
    std::vector<int>    create_nunits;       // no default
    std::vector<int>    create_units_type;   // no default
    // =====
    bool                train                = false;
    // -----
    std::string         train_X_file;        // required for training
    // -----
    int                 train_nepoch         = 5000;
    int                 train_npts_per_batch = 10;     // should be between 10 and 100 (suggested 10) --- Hinton, practical guide
    // -----
    int                 train_K_start        = 1;
    int                 train_K_end          = 4;
    double              train_K_rate         = 0.01;
    // -----
    bool                train_sample_v       = false;
    bool                train_sample_hdata   = true;   // true is closer to the mathematical model of an RBM; false allows slightly faster training
    // -----
    double              train_lr             = 0.01;   // Hinton's practical guide suggests 10e-3*|weights| TODO: check and update
    std::vector<double> train_lr_dot;        // default set below TODO: NOTE: not sure how to handle this and lr_adj, p_dot, and p_adj (if the user doesn't want this, they must set one size and adj=1)
    std::vector<double> train_lr_adj;        // default set below
    double              train_lr_min         = 0.0001;
    double              train_lr_max         = 1.;
    // -----
    double              train_p_start        = 0.1;    // suggested initial 0.5, then increase to 0.9 (Hinton, practical guide); 0.1 (start) to 0.9 (end) suggested elsewhere (Masters, DBNs in C++)
                                                       // NOTE: ... and too large initial (without _dot and _adj below seem to lead to instabilities)
    double              train_p_end          = 0.9;
    double              train_p_rate         = 0.01;
    std::vector<double> train_p_dot;         // default set below
    std::vector<double> train_p_adj;         // default set below
    // -----
    double              train_w_penalty      = 0.0001; // 0.0001 is a good default value (Masters, DBNs in C++)
    // -----
//    int free_energy_npts   = 100; // this is the same (number) for both the training subset and the validation data
//    int free_energy_nepoch = 5;   // this only needs to be calculated every few epochs
//    int free_energy_window = 200; // calculate the slope of the free energy over this window
    // -----
    int                 train_convg_nno_improvement = 500;     // 500 is a good default value (Masters, DBNs in C++)
    double              train_convg_max_inc_max_w   = 0.00001; // 0.00001 is a good default value (Masters, DBNs in C++)
    // =====
    bool                test                 = false;
    // -----
    std::string         test_X_file;         // required for testing
    std::string         test_H_file;         // required for testing


    try
    {
        //=========================================================
        // SET THE OPTIONS
        //=========================================================

        po::options_description desc("options");

        desc.add_options()
        ("dbn_file",             po::value<std::string>(&dbn_file)->required(),   "DBN filename (input or output)")
        // =====
        ("create.nunits",        po::value<std::vector<int>>(&create_nunits),     "[create] number of units")
        ("create.units_type",    po::value<std::vector<int>>(&create_units_type), "[create] type of units")
        // =====
        ("train.X_file",         po::value<std::string>(&train_X_file),           "[train] X filename")
        // -----
        ("train.nepoch",         po::value<int>(&train_nepoch),                   "[train] number of epochs")
        ("train.npts_per_batch", po::value<int>(&train_npts_per_batch),           "[train] number of points per batch")
        // -----
        ("train.K_start",        po::value<int>(&train_K_start),                  "[train] K start")
        ("train.K_end",          po::value<int>(&train_K_end),                    "[train] K end")
        ("train.K_rate",         po::value<double>(&train_K_rate),                "[train] K rate")
        // -----
        ("train.sample_v",       po::value<bool>(&train_sample_v),                "[train] sample the visible units")
        ("train.sample_hdata",   po::value<bool>(&train_sample_hdata),            "[train] sample the hidden units for h_data")
        // -----
        ("train.lr",             po::value<double>(&train_lr),                    "[train] learning rate")
        ("train.lr_dot",         po::value<std::vector<double>>(&train_lr_dot),   "[train] learning rate adjustment dot")
        ("train.lr_adj",         po::value<std::vector<double>>(&train_lr_adj),   "[train] learning rate adjustment")
        ("train.lr_min",         po::value<double>(&train_lr_min),                "[train] minimum learning rate")
        ("train.lr_max",         po::value<double>(&train_lr_max),                "[train] maximum learning rate")
        // -----
        ("train.p_start",        po::value<double>(&train_p_start),               "[train] starting momentum")
        ("train.p_end",          po::value<double>(&train_p_end),                 "[train] ending momentum")
        ("train.p_rate",         po::value<double>(&train_p_rate),                "[train] momentum rate")
        ("train.p_dot",          po::value<std::vector<double>>(&train_p_dot),    "[train] momentum adjustment dot")
        ("train.p_adj",          po::value<std::vector<double>>(&train_p_adj),    "[train] momentum adjustment")
        // -----
        ("train.w_penalty",      po::value<double>(&train_w_penalty),             "[train] weight penalty")
        // -----
        ("train.convg_nno_improvement", po::value<int>(&train_convg_nno_improvement),  "[train] convergence; number of no improvement")
        ("train.convg_max_inc_max_w",   po::value<double>(&train_convg_max_inc_max_w), "[train] convergence; max W_inc / max W")
        // =====
        ("test.X_file",          po::value<std::string>(&test_X_file),            "[test] X filename (data)")
        ("test.H_file",          po::value<std::string>(&test_H_file),            "[test] H filename (hidden units)")
        ;

        //=========================================================
        // READ THE FILE
        //=========================================================

        std::ifstream ifs(param_filename, std::ios::in);
        if(!ifs)
        {
            // Without this check an unreadable file would surface as a
            // misleading "required option 'dbn_file' is missing" error.
            fail("cannot open parameter file '" + param_filename + "'");
        }

        po::variables_map vm;
        po::store(po::parse_config_file(ifs, desc), vm);
        po::notify(vm);

        ifs.close();

        //=========================================================
        // INFER CALCULATION TYPE(S)
        //=========================================================

        // True if the file explicitly set at least one option of the form
        // "<section>.<name>". Only options present in the file (or defaulted
        // ones, which are skipped) appear in vm, so this tracks every option
        // registered above without hand-maintaining a list of names.
        const auto any_option_in_section = [&vm](const std::string& section)
        {
            const std::string prefix = section + '.';
            for(const auto& kv : vm)
            {
                if(!kv.second.defaulted() &&
                   kv.first.compare(0, prefix.size(), prefix) == 0)
                {
                    return true;
                }
            }
            return false;
        };

        create = any_option_in_section("create");
        train  = any_option_in_section("train");
        test   = any_option_in_section("test");

        //=========================================================
        // CONSISTENCY / ERROR CHECK
        //=========================================================
        // NOTE: default vectors are also set here

        // TODO: it would be nice to determine interrelations between parameters, rather than all of the messy if checks below
        // TODO: ... but I am not sure that is possible with Boost program_options

        //---------------------------------------------------------
        // CREATE
        //---------------------------------------------------------
        if(create_nunits.size() != create_units_type.size())
        {
            fail("create_nunits.size() != create_units_type.size()");
        }

        if(create && create_nunits.empty())
        {
            fail("create specified, but create_nunits.empty()");
        }

        //---------------------------------------------------------
        // TRAIN
        //---------------------------------------------------------
        if(train && train_X_file.empty())
        {
            fail("train specified, but train_X_file.empty()");
        }

        // ----- LEARNING RATE ADJUSTMENTS -----
        if(train_lr_dot.size() != train_lr_adj.size())
        {
            fail("train_lr_dot.size() != train_lr_adj.size()");
        }

        if(train_lr_dot.empty())
        {
            // sizes already checked equal, so train_lr_adj is empty too
            train_lr_dot = {0.5, 0.3};
            train_lr_adj = {1.05, 1.01};
        }

        // ----- MOMENTUM ADJUSTMENTS -----
        if(train_p_dot.size() != train_p_adj.size())
        {
            fail("train_p_dot.size() != train_p_adj.size()");
        }

        if(train_p_dot.empty())
        {
            // sizes already checked equal, so train_p_adj is empty too
            train_p_dot = {0.3};
            train_p_adj = {1.5};
        }

        //---------------------------------------------------------
        // TEST
        //---------------------------------------------------------
        if(test)
        {
            if(test_X_file.empty())
            {
                fail("test specified, but test_X_file.empty()");
            }

            if(test_H_file.empty())
            {
                fail("test specified, but test_H_file.empty()");
            }
        }
    }
    catch(const std::exception& e)
    {
        // parse errors, unknown options, bad values, missing required dbn_file
        fail(e.what());
    }
    catch(...)
    {
        fail("unknown");
    }

    return std::make_tuple(
                           dbn_file,
                           // =====
                           create,
                           // -----
                           create_nunits,
                           create_units_type,
                           // =====
                           train,
                           // -----
                           train_X_file,
                           // -----
                           train_nepoch,
                           train_npts_per_batch,
                           // -----
                           train_K_start,
                           train_K_end,
                           train_K_rate,
                           // -----
                           train_sample_v,
                           train_sample_hdata,
                           // -----
                           train_lr,
                           train_lr_dot,
                           train_lr_adj,
                           train_lr_min,
                           train_lr_max,
                           // -----
                           train_p_start,
                           train_p_end,
                           train_p_rate,
                           train_p_dot,
                           train_p_adj,
                           // -----
                           train_w_penalty,
                           // -----
                           train_convg_nno_improvement,
                           train_convg_max_inc_max_w,
                           // =====
                           test,
                           // -----
                           test_X_file,
                           test_H_file
                           );
}
